{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3fafb7f6",
   "metadata": {},
   "source": [
    "# Exploring Dependency Groups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0ffcf934",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "import sys, os\n",
    "sys.path.append(os.path.abspath(\"../\"))\n",
    "\n",
    "import torch\n",
    "from torchvision.models import resnet18\n",
    "import torch_pruning as tp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93f4f573",
   "metadata": {},
   "source": [
    "### Grouping\n",
    "\n",
    "In this part, we take a closer look at the ``DependencyGraph`` module and show how it facilitates structural pruning. First, let's fetch a group from a ResNet-18."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fea4aadd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 0. prepare your model and example inputs\n",
    "model = resnet18(pretrained=True).eval()\n",
    "example_inputs = torch.randn(1,3,224,224)\n",
    "\n",
    "# 1. build dependency graph for resnet18\n",
    "DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)\n",
    "\n",
    "# 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].\n",
    "pruning_idxs = [2, 6, 9]\n",
    "group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b719e024",
   "metadata": {},
   "source": [
    "### Indexing\n",
    "\n",
    "In Torch-Pruning, the dependencies in a group are organized as an iterable list. Within a group, the operation requested by the user is the root operation and comes first. For instance, if we prune 'model.conv1', the first dependency in the group reflects this action: pruning the output channels of conv1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d61396ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "([DEP] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), [2, 6, 9])\n"
     ]
    }
   ],
   "source": [
    "print(group[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "692b3d64",
   "metadata": {},
   "source": [
    "Each item in the group pairs a dependency with its pruning indices, i.e., the channels to be pruned. Here we aim to remove the channels at (zero-based) indices 2, 6, and 9 of conv1, and the next item in the group carries the same indices over to bn1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fbbfbf9d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dep: [DEP] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))\n",
      "Indices: [2, 6, 9]\n"
     ]
    }
   ],
   "source": [
    "print(\"Dep:\", group[1][0])\n",
    "print(\"Indices:\", group[1][1])"
   ]
  },
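  {
   "cell_type": "markdown",
   "id": "a1c3e5f7",
   "metadata": {},
   "source": [
    "The indices are zero-based positions along conv1's output-channel dimension. As a minimal sketch in plain Python (no Torch-Pruning calls; the `toy_*` names are illustrative and not part of the library), the cell below lists which of the 64 output channels would remain after removing positions [2, 6, 9]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2d4f6a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy sketch: plain Python only, no Torch-Pruning calls.\n",
    "toy_num_channels = 64        # conv1 of ResNet-18 has 64 output channels\n",
    "toy_idxs = [2, 6, 9]         # zero-based positions to remove\n",
    "toy_removed = set(toy_idxs)\n",
    "\n",
    "# Keep every position that is not marked for removal.\n",
    "toy_keep = [i for i in range(toy_num_channels) if i not in toy_removed]\n",
    "\n",
    "print(len(toy_keep))   # 61\n",
    "print(toy_keep[:8])    # [0, 1, 3, 4, 5, 7, 8, 10]"
   ]
  },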
  {
   "cell_type": "markdown",
   "id": "d47a65cf",
   "metadata": {},
   "source": [
    "Let's look more closely at the concept of a dependency in DepGraph. A dependency is represented as an edge connecting two nodes, indicating that the two are inter-dependent. Each dependency holds two pruning functions: 1) a trigger function, a pruning operation that breaks the dependency when applied on its own, and 2) a handler function, which repairs the broken dependency caused by the trigger.  \n",
    "\n",
    "For instance, consider the simple Conv-BN dependency between conv1 and bn1. If we remove an output channel of 'conv1', the corresponding channel of 'bn1' must be pruned as well. This dependency is illustrated in the following example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ee248dc5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Source Node: Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "Target Node: BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "Trigger Function: <bound method ConvPruner.prune_out_channels of <torch_pruning.pruner.function.ConvPruner object at 0x7f5cbe611b20>>\n",
      "Handler Function: <bound method BatchnormPruner.prune_out_channels of <torch_pruning.pruner.function.BatchnormPruner object at 0x7f5cbe611bb0>>\n"
     ]
    }
   ],
   "source": [
    "print(\"Source Node:\", group[1][0].source.module)  # .source.module is the underlying nn.Module\n",
    "print(\"Target Node:\", group[1][0].target.module)  # .target.module is the underlying nn.Module\n",
    "print(\"Trigger Function:\", group[1][0].trigger)\n",
    "print(\"Handler Function:\", group[1][0].handler)"
   ]
  },
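  {
   "cell_type": "markdown",
   "id": "c3e5a7b9",
   "metadata": {},
   "source": [
    "To make the trigger/handler split concrete, here is a toy sketch in plain Python. It does not use Torch-Pruning: `toy_prune` and the list-based channels are hypothetical stand-ins for real pruning functions and modules, chosen only to show how a trigger creates a mismatch and a handler removes it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4f6b8c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy sketch: hypothetical names, plain Python only, no Torch-Pruning calls.\n",
    "def toy_prune(channels, idxs):\n",
    "    # Return a copy of `channels` without the positions listed in `idxs`.\n",
    "    removed = set(idxs)\n",
    "    return [c for i, c in enumerate(channels) if i not in removed]\n",
    "\n",
    "toy_conv = list(range(12))   # stand-in for conv1's output channels\n",
    "toy_bn = list(range(12))     # stand-in for bn1's per-channel statistics\n",
    "toy_idxs = [2, 6, 9]\n",
    "\n",
    "# Trigger: pruning the conv stand-in alone leaves the two layers inconsistent.\n",
    "toy_conv = toy_prune(toy_conv, toy_idxs)\n",
    "toy_broken = len(toy_conv) != len(toy_bn)\n",
    "\n",
    "# Handler: pruning the same positions in the BN stand-in repairs the mismatch.\n",
    "toy_bn = toy_prune(toy_bn, toy_idxs)\n",
    "toy_repaired = len(toy_conv) == len(toy_bn)\n",
    "\n",
    "print(toy_broken, toy_repaired, toy_conv == toy_bn)   # True True True"
   ]
  },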
  {
   "cell_type": "markdown",
   "id": "0b7f8d24",
   "metadata": {},
   "source": [
    "### Pruning with Dependency\n",
    "\n",
    "In Torch-Pruning, we can \"execute\" a dependency by calling it with the pruning indices, which applies its handler function. Here we execute only the first dependency, pruning conv1 without repairing the dependencies this breaks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "102e0de5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx = group[0][1]\n",
    "dep = group[0][0]\n",
    "dep(idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e4c8270",
   "metadata": {},
   "source": [
    "However, if we try to run a forward pass through this model as usual, an error occurs: ``running_mean should contain 61 elements not 64``. conv1 now outputs 61 channels while bn1 still expects 64."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a0a3e191",
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "running_mean should contain 61 elements not 64",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn [7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m224\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m224\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1129\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.9/site-packages/torchvision/models/resnet.py:285\u001b[0m, in \u001b[0;36mResNet.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    284\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 285\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_forward_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.9/site-packages/torchvision/models/resnet.py:269\u001b[0m, in \u001b[0;36mResNet._forward_impl\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    266\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_forward_impl\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m    267\u001b[0m     \u001b[38;5;66;03m# See note [TorchScript super()]\u001b[39;00m\n\u001b[1;32m    268\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv1(x)\n\u001b[0;32m--> 269\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbn1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    270\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrelu(x)\n\u001b[1;32m    271\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmaxpool(x)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1129\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:168\u001b[0m, in \u001b[0;36m_BatchNorm.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    161\u001b[0m     bn_training \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrunning_mean \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrunning_var \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m    163\u001b[0m \u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    164\u001b[0m \u001b[38;5;124;03mBuffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be\u001b[39;00m\n\u001b[1;32m    165\u001b[0m \u001b[38;5;124;03mpassed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are\u001b[39;00m\n\u001b[1;32m    166\u001b[0m \u001b[38;5;124;03mused for normalization (i.e. in eval mode when buffers are not None).\u001b[39;00m\n\u001b[1;32m    167\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m--> 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_norm\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    169\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    170\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;66;43;03m# If buffers are not to be tracked, ensure that they won't be updated\u001b[39;49;00m\n\u001b[1;32m    171\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrunning_mean\u001b[49m\n\u001b[1;32m    172\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrack_running_stats\u001b[49m\n\u001b[1;32m    173\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    174\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrunning_var\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrack_running_stats\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    175\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    176\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    177\u001b[0m \u001b[43m    \u001b[49m\u001b[43mbn_training\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    178\u001b[0m \u001b[43m    \u001b[49m\u001b[43mexponential_average_factor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    179\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    180\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.9/site-packages/torch/nn/functional.py:2438\u001b[0m, in \u001b[0;36mbatch_norm\u001b[0;34m(input, running_mean, running_var, weight, bias, training, momentum, eps)\u001b[0m\n\u001b[1;32m   2435\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m training:\n\u001b[1;32m   2436\u001b[0m     _verify_batch_size(\u001b[38;5;28minput\u001b[39m\u001b[38;5;241m.\u001b[39msize())\n\u001b[0;32m-> 2438\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_norm\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2439\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrunning_mean\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrunning_var\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmomentum\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackends\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcudnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menabled\u001b[49m\n\u001b[1;32m   2440\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: running_mean should contain 61 elements not 64"
     ]
    }
   ],
   "source": [
    "print(model(torch.randn(1,3,224,224)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10117ec4",
   "metadata": {},
   "source": [
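    "To address this issue, we use \"group pruning\", which removes the whole group of coupled parameters from the model at once. We reload the model first, since the one above was left in an inconsistent state."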
   ]
  },
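  {
   "cell_type": "markdown",
   "id": "e5a7c9d1",
   "metadata": {},
   "source": [
    "Before running it on the real model, here is the idea in miniature: pruning a group means applying every (target, indices) pair in it so that all coupled widths shrink together. This is a toy sketch in plain Python, not the library's implementation; the layer names and widths in `toy_widths` are illustrative assumptions and do not reproduce the actual contents of the ResNet-18 group."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6b8d0e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy sketch: plain Python only, with illustrative layer names and widths.\n",
    "toy_widths = {'conv1.out': 64, 'bn1': 64, 'layer1.0.conv1.in': 64}\n",
    "toy_idxs = [2, 6, 9]\n",
    "\n",
    "# A toy group: one (target, indices) pair per coupled layer.\n",
    "toy_group = [(name, toy_idxs) for name in toy_widths]\n",
    "\n",
    "# Apply every pair so that no layer is left with a stale width.\n",
    "for name, idxs in toy_group:\n",
    "    toy_widths[name] -= len(idxs)\n",
    "\n",
    "print(toy_widths)   # {'conv1.out': 61, 'bn1': 61, 'layer1.0.conv1.in': 61}"
   ]
  },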
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4d2da737",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = resnet18(pretrained=True).eval()\n",
    "example_inputs = torch.randn(1,3,224,224)\n",
    "\n",
    "# 1. build dependency graph for resnet18\n",
    "DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)\n",
    "\n",
    "# 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].\n",
    "pruning_idxs = [2, 6, 9]\n",
    "group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )\n",
    "\n",
    "group.prune()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "45a926e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 1000])\n"
     ]
    }
   ],
   "source": [
    "print(model(torch.randn(1,3,224,224)).shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
